# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from examples.tensorflow.common.object_detection.ops import nms
from examples.tensorflow.common.object_detection.utils import box_utils


def multilevel_propose_rois(
    rpn_boxes,
    rpn_scores,
    anchor_boxes,
    image_shape,
    rpn_pre_nms_top_k=2000,
    rpn_post_nms_top_k=1000,
    rpn_nms_threshold=0.7,
    rpn_score_threshold=0.0,
    rpn_min_size_threshold=0.0,
    decode_boxes=True,
    clip_boxes=True,
    use_batched_nms=False,
    apply_sigmoid_to_score=True,
):
    """Proposes RoIs given a group of candidates from different FPN levels.

    The following describes the steps:
        1. For each individual level:
            a. Apply sigmoid transform if specified.
            b. Decode boxes if specified.
            c. Clip boxes if specified.
            d. Filter small boxes and those that fall outside the image if
               specified.
            e. Apply pre-NMS filtering including pre-NMS top k and score
               thresholding.
            f. Apply NMS, or, when NMS is disabled, keep the per-level top k
               boxes by score.
        2. Aggregate the per-level boxes.
        3. Apply an overall top k to generate the final selected RoIs.

    Args:
        rpn_boxes: a dict with keys representing FPN levels and values representing
            box tensors of shape [batch_size, feature_h, feature_w, num_anchors * 4].
        rpn_scores: a dict with keys representing FPN levels and values representing
            logit tensors of shape [batch_size, feature_h, feature_w, num_anchors].
            The spatial and anchor dimensions must be statically known, because
            they are used to compute the number of boxes per level.
        anchor_boxes: a dict with values representing anchor box tensors of shape
            [batch_size, feature_h, feature_w, num_anchors * 4]. It is indexed
            with `int(level)` for every key `level` of `rpn_scores`.
        image_shape: a tensor of shape [batch_size, 2] where the last dimension are
            [height, width] of the scaled image.
        rpn_pre_nms_top_k: an integer of top scoring RPN proposals *per level* to
            keep before applying NMS. Default: 2000.
        rpn_post_nms_top_k: an integer of top scoring RPN proposals to keep after
            applying NMS. It is used both as the per-level cap and as the cap on
            the final aggregated selection. Default: 1000.
        rpn_nms_threshold: a float between 0 and 1 representing the IoU threshold
            used for NMS. If 0.0, no NMS is applied and each level contributes
            its `rpn_post_nms_top_k` highest scoring boxes. Default: 0.7.
        rpn_score_threshold: a float between 0 and 1 representing the minimal box
            score to keep before applying NMS. This is often used as a pre-filtering
            step for better performance. If 0, no filtering is applied. It only
            takes effect when NMS is applied. Default: 0.
        rpn_min_size_threshold: a float representing the minimal box size in each
            side (w.r.t. the scaled image) to keep before applying NMS. This is often
            used as a pre-filtering step for better performance. If 0, no filtering is
            applied. Default: 0.
        decode_boxes: a boolean indicating whether `rpn_boxes` needs to be decoded
            using `anchor_boxes`. If False, `rpn_boxes` is used directly.
            NOTE: `anchor_boxes` is still indexed and reshaped for every level in
            that case, so it must remain a valid dict. Default: True.
        clip_boxes: a boolean indicating whether boxes are first clipped to the
            scaled image size before applying NMS. If False, no clipping is
            applied. NOTE: `image_shape` is still expanded with
            `tf.expand_dims` in that case, so it must remain a valid tensor.
            Default: True.
        use_batched_nms: a boolean indicating whether NMS is applied in batch using
            `tf.image.combined_non_max_suppression`. Currently only available in
            CPU/GPU. Default: False.
        apply_sigmoid_to_score: a boolean indicating whether apply sigmoid to
            `rpn_scores` before applying NMS. Default: True.

    Returns:
        selected_rois: the boxes returned by the final `box_utils.top_k_boxes`
            call, i.e. the top `k` of the aggregated per-level boxes where
            `k = min(num_aggregated_boxes, rpn_post_nms_top_k)`. Coordinates are
            w.r.t. the scaled image; expected shape is [batch_size, k, 4].
        selected_roi_scores: the matching scores. The aggregated scores are a
            rank-2 tensor [batch_size, num_aggregated_boxes], so the expected
            shape is [batch_size, k].
    """

    with tf.name_scope("multilevel_propose_rois"):
        rois = []
        roi_scores = []
        # Insert a box axis so the shape can be combined with per-box tensors
        # by the box_utils helpers.
        image_shape = tf.expand_dims(image_shape, axis=1)
        for level in sorted(rpn_scores.keys()):
            with tf.name_scope("level_{}".format(level)):
                # Static shapes are required: they become Python ints used for
                # reshaping and for capping the top-k sizes below.
                _, feature_h, feature_w, num_anchors_per_location = rpn_scores[level].get_shape().as_list()

                # Flatten the spatial and anchor dimensions into one box axis.
                num_boxes = feature_h * feature_w * num_anchors_per_location
                this_level_scores = tf.reshape(rpn_scores[level], [-1, num_boxes])
                this_level_boxes = tf.reshape(rpn_boxes[level], [-1, num_boxes, 4])
                # Anchors are looked up by the integer form of the level key and
                # cast to the score dtype.
                this_level_anchors = tf.cast(
                    tf.reshape(anchor_boxes[int(level)], [-1, num_boxes, 4]), this_level_scores.dtype
                )

                if apply_sigmoid_to_score:
                    this_level_scores = tf.sigmoid(this_level_scores)

                if decode_boxes:
                    this_level_boxes = box_utils.decode_boxes(this_level_boxes, this_level_anchors)
                if clip_boxes:
                    this_level_boxes = box_utils.clip_boxes(this_level_boxes, image_shape)

                if rpn_min_size_threshold > 0.0:
                    this_level_boxes, this_level_scores = box_utils.filter_boxes(
                        this_level_boxes, this_level_scores, image_shape, rpn_min_size_threshold
                    )

                # A level cannot contribute more boxes than it contains.
                this_level_pre_nms_top_k = min(num_boxes, rpn_pre_nms_top_k)
                this_level_post_nms_top_k = min(num_boxes, rpn_post_nms_top_k)
                if rpn_nms_threshold > 0.0:
                    if use_batched_nms:
                        # combined_non_max_suppression expects boxes of shape
                        # [batch, num_boxes, q, 4] and scores of shape
                        # [batch, num_boxes, num_classes]; a single shared
                        # box set and a single class are used here.
                        this_level_rois, this_level_roi_scores, _, _ = tf.image.combined_non_max_suppression(
                            tf.expand_dims(this_level_boxes, axis=2),
                            tf.expand_dims(this_level_scores, axis=-1),
                            max_output_size_per_class=this_level_pre_nms_top_k,
                            max_total_size=this_level_post_nms_top_k,
                            iou_threshold=rpn_nms_threshold,
                            score_threshold=rpn_score_threshold,
                            pad_per_class=False,
                            clip_boxes=False,
                        )
                    else:
                        if rpn_score_threshold > 0.0:
                            this_level_boxes, this_level_scores = box_utils.filter_boxes_by_scores(
                                this_level_boxes, this_level_scores, rpn_score_threshold
                            )

                        this_level_boxes, this_level_scores = box_utils.top_k_boxes(
                            this_level_boxes, this_level_scores, k=this_level_pre_nms_top_k
                        )

                        this_level_roi_scores, this_level_rois = nms.sorted_non_max_suppression_padded(
                            this_level_scores,
                            this_level_boxes,
                            max_output_size=this_level_post_nms_top_k,
                            iou_threshold=rpn_nms_threshold,
                        )
                else:
                    # NMS disabled: select directly from this level's boxes.
                    # (The boxes must be passed here; `this_level_rois` does
                    # not exist yet on this path.)
                    this_level_rois, this_level_roi_scores = box_utils.top_k_boxes(
                        this_level_boxes, this_level_scores, k=this_level_post_nms_top_k
                    )

                rois.append(this_level_rois)
                roi_scores.append(this_level_roi_scores)

        # Merge the per-level candidates along the box axis.
        all_rois = tf.concat(rois, 1)
        all_roi_scores = tf.concat(roi_scores, 1)

        with tf.name_scope("top_k_rois"):
            _, num_valid_rois = all_roi_scores.get_shape().as_list()
            overall_top_k = min(num_valid_rois, rpn_post_nms_top_k)

            selected_rois, selected_roi_scores = box_utils.top_k_boxes(all_rois, all_roi_scores, k=overall_top_k)

        return selected_rois, selected_roi_scores


class ROIGenerator:
    """Proposes RoIs for the second stage processing.

    Two sets of proposal settings are stored: one used when training and one
    (the `test_*` values) used otherwise. `__call__` selects between them.
    """

    def __init__(self, params):
        """Copies the proposal settings from `params`.

        Args:
            params: an object exposing the attributes `rpn_pre_nms_top_k`,
                `rpn_post_nms_top_k`, `rpn_nms_threshold`,
                `rpn_score_threshold`, `rpn_min_size_threshold`, their
                `test_`-prefixed counterparts, and `use_batched_nms`.
        """
        # Each training-time setting is stored next to its inference-time twin.
        self._rpn_pre_nms_top_k = params.rpn_pre_nms_top_k
        self._test_rpn_pre_nms_top_k = params.test_rpn_pre_nms_top_k

        self._rpn_post_nms_top_k = params.rpn_post_nms_top_k
        self._test_rpn_post_nms_top_k = params.test_rpn_post_nms_top_k

        self._rpn_nms_threshold = params.rpn_nms_threshold
        self._test_rpn_nms_threshold = params.test_rpn_nms_threshold

        self._rpn_score_threshold = params.rpn_score_threshold
        self._test_rpn_score_threshold = params.test_rpn_score_threshold

        self._rpn_min_size_threshold = params.rpn_min_size_threshold
        self._test_rpn_min_size_threshold = params.test_rpn_min_size_threshold

        # Shared by both modes.
        self._use_batched_nms = params.use_batched_nms

    def __call__(self, boxes, scores, anchor_boxes, image_shape, is_training):
        """Generates RoI proposals.

        Forwards the inputs to `multilevel_propose_rois` with box decoding,
        box clipping and score sigmoid enabled, using the training settings
        when `is_training` is truthy and the test settings otherwise.

        Args:
            boxes: a dict with keys representing FPN levels and values
                representing box tensors of shape
                [batch_size, feature_h, feature_w, num_anchors * 4].
            scores: a dict with keys representing FPN levels and values
                representing logit tensors of shape
                [batch_size, feature_h, feature_w, num_anchors].
            anchor_boxes: a dict with keys representing FPN levels and values
                representing anchor box tensors of shape
                [batch_size, feature_h, feature_w, num_anchors * 4].
            image_shape: a tensor of shape [batch_size, 2] where the last
                dimension are [height, width] of the scaled image.
            is_training: a bool indicating whether it is in training or
                inference mode.

        Returns:
            A `(proposed_rois, proposed_roi_scores)` pair, exactly as returned
            by `multilevel_propose_rois`: the box coordinates of the proposed
            RoIs w.r.t. the scaled image, and their scores.
        """
        if is_training:
            pre_nms_top_k = self._rpn_pre_nms_top_k
            post_nms_top_k = self._rpn_post_nms_top_k
            nms_threshold = self._rpn_nms_threshold
            score_threshold = self._rpn_score_threshold
            min_size_threshold = self._rpn_min_size_threshold
        else:
            pre_nms_top_k = self._test_rpn_pre_nms_top_k
            post_nms_top_k = self._test_rpn_post_nms_top_k
            nms_threshold = self._test_rpn_nms_threshold
            score_threshold = self._test_rpn_score_threshold
            min_size_threshold = self._test_rpn_min_size_threshold

        rois, roi_scores = multilevel_propose_rois(
            boxes,
            scores,
            anchor_boxes,
            image_shape,
            rpn_pre_nms_top_k=pre_nms_top_k,
            rpn_post_nms_top_k=post_nms_top_k,
            rpn_nms_threshold=nms_threshold,
            rpn_score_threshold=score_threshold,
            rpn_min_size_threshold=min_size_threshold,
            decode_boxes=True,
            clip_boxes=True,
            use_batched_nms=self._use_batched_nms,
            apply_sigmoid_to_score=True,
        )
        return rois, roi_scores
